Prediction of seizure-like activity in in vitro cortical populations

Our project will try to take advantage of machine learning methods to try to predict, with ~100 ms in advance, the onset of a “seizure". The recordings used in the development phase were from spiking activity of cortical neurons on microelectrode arrays chips with 256 electrodes. These recordings exhibit seizure-like activity, where a large fraction of the population synchronizes.

Epileptic seizure is a period of symptoms due to abnormally excessive or synchronous neuronal activity in the brain.

Requirements

First some modules need to be imported:

In [1]:
# These are the imports of the McsData module
import sys
sys.path.append('D:\\Programming\\McsDataManagement\\McsPyDataTools\\McsPyDataTools')
import McsPy.McsData
import McsPy.McsCMOS

# numpy is numpy ...
import numpy as np

# bokeh adds more interactivity to the plots within notebooks. Adds toolbar at the top-right corner of the plot.
# Allows zooming, panning and saving of the plot
import bokeh.io
from bokeh.io import show
from bokeh.layouts import column
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
from bokeh.models.widgets import Slider
bokeh.io.output_notebook()


# tensorflow & other machine learning utilities
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize

# utilities
import time
Loading BokehJS ...

Import Data

Next we need to access the rawdata by initializing an instance of the RawData class from the McsData module by handing over the path to the file. The filepath points to the folder TestData within the folder, where this notebook resides.

To check if we got access to the file we can look at its contents by printing the info that got extracted when the RawData object was initialized.

In [2]:
channel_raw_data = McsPy.McsData.RawData('./TestData/cortical_same_chip/2019-02-11T11-42-58Chip_3_DIV17.h5')

print(channel_raw_data.comment)
print(channel_raw_data.date)
print(channel_raw_data.clr_date)
print(channel_raw_data.date_in_clr_ticks)
print(channel_raw_data.file_guid)
print(channel_raw_data.mea_name)
print(channel_raw_data.mea_sn)
print(channel_raw_data.mea_layout)
print(channel_raw_data.program_name)
print(channel_raw_data.program_version)
2019-02-11 11:42:58.197496
11 de fevereiro de 2019
636854821781974996
4eb6c972-3e6b-4c68-9115-bca9fb94b90a
256MEA for MEA2100-256

256MEA_MEA2100
Multi Channel Experimenter
2.9.3.18310

Visualize Analog Stream

The data of the stream can be found under .channel_data.

In [3]:
analog_stream_0 = channel_raw_data.recordings[0].analog_streams[0]
analog_stream_0_data = analog_stream_0.channel_data
#analog_stream_0_data = analog_stream_0_data[:,200000:300000]
electrodes = analog_stream_0.channel_data.shape[0]
print(analog_stream_0_data)
Recording_0 <HDF5 group "/Data/Recording_0" (1 members)>
Stream_0 <HDF5 group "/Data/Recording_0/AnalogStream/Stream_0" (3 members)>
ChannelData <HDF5 dataset "ChannelData": shape (252, 3000000), type "<i4">
ChannelDataTimeStamps <HDF5 dataset "ChannelDataTimeStamps": shape (1, 3), type "<i8">
InfoChannel <HDF5 dataset "InfoChannel": shape (252,), type "|V108">
<HDF5 dataset "ChannelData": shape (252, 3000000), type "<i4">
In [4]:
def modify_doc(doc):
    electrode_index = 0
    source = ColumnDataSource(data=dict(x=range(analog_stream_0_data.shape[1]), y=analog_stream_0_data[electrode_index]))
    
    bfig = figure(plot_width=900, plot_height=400, title='Voltage Activity (volt) / Time (microsecond)')
    bfig.circle('x', 'y', source=source)
    bfig.xaxis.axis_label = 'Microsecond'
    bfig.yaxis.axis_label = 'Voltage'
    bfig.ygrid.minor_grid_line_color = 'navy'
    bfig.ygrid.minor_grid_line_alpha = 0.1
    

    def callback(attr, old, new):
        source.data = ColumnDataSource(data=dict(x=range(analog_stream_0_data.shape[1]), y=analog_stream_0_data[new-1])).data

    slider = Slider(start=1, end=electrodes, value=1, step=1, title="electrode")
    slider.on_change('value', callback)

    doc.add_root(column(slider, bfig))


show(modify_doc) #show(modify_doc, notebook_url="localhost:8890") 

Detecting Spikes

The detection of spikes will follow the formula:

spike = abs(voltage) > abs(std_electrode_voltage*factor)

You can define the factor below (using 5 as default)

In [ ]:
def aux(value, std, factor):
    return False if abs(value) < abs(std*factor) else True

def threshold(electrode_values, factor=5):
    std = np.std(electrode_values)
    return np.array([aux(value, std, factor) for value in electrode_values])

start = time.time()    
np_analog_stream_0_data_spikes = np.array([threshold(electrode_values, 5) for electrode_values in analog_stream_0_data])
end = time.time()
print("Elapsed time: " + str(end-start))
print(np_analog_stream_0_data_spikes)
In [6]:
def modify_doc(doc):
    electrode_index = 0
    colors_spikes = ['green' if spike else 'blue' for spike in np_analog_stream_0_data_spikes[electrode_index]]
    source = ColumnDataSource(data=dict(x=range(analog_stream_0_data.shape[1]), y=analog_stream_0_data[electrode_index], color=colors_spikes))
    bfig = figure(plot_width=900, plot_height=400, title='Voltage Activity (volt) / Time (microsecond)')
    bfig.circle(x='x', y='y', source=source, color='color')
    bfig.xaxis.axis_label = 'Microsecond'
    bfig.yaxis.axis_label = 'Voltage'
    bfig.ygrid.minor_grid_line_color = 'navy'
    bfig.ygrid.minor_grid_line_alpha = 0.1
    

    def callback(attr, old, new):
        colors_spikes = ['green' if spike else 'blue' for spike in np_analog_stream_0_data_spikes[new-1]]
        source.data = ColumnDataSource(data=dict(x=range(analog_stream_0_data.shape[1]), y=analog_stream_0_data[new-1], color=colors_spikes)).data


    slider = Slider(start=1, end=electrodes, value=1, step=1, title="electrode")
    slider.on_change('value', callback)

    doc.add_root(column(slider, bfig))


show(modify_doc) #show(modify_doc, notebook_url="localhost:8890") 

Detecting Seizures + Raster Plot

The detection of seizures will follow the formula:

seizures = #electrodes spiking >= factor * electrodes

You can define the factor below (using 2.5% as default)

In [7]:
spike_times=[]
electrode_spike=[]
seizure_dict={}
for electrode, line in enumerate(np_analog_stream_0_data_spikes):
    for time, val in enumerate(line):
        if val:
            spike_times.append(time)
            electrode_spike.append(electrode)
            set_electrodes = seizure_dict.get(time)
            if set_electrodes is not None:
                set_electrodes.add(electrode)
                seizure_dict[time] = set_electrodes
            else:
                seizure_dict[time] = {electrode}
                
factor=0.05
threshold = electrodes * factor
seizure_detection = np.zeros(np_analog_stream_0_data_spikes.shape[1], dtype=bool)
for time in seizure_dict:
    spikes = len(seizure_dict[time])
    if spikes > threshold:
        seizure_detection[time]=True


colors = ['red' if seizure_detection[t] else 'blue' for t in spike_times]
In [8]:
plot_size_and_tools = {'plot_height': 400, 'plot_width': 900,
                       'x_range': [0,np_analog_stream_0_data_spikes.shape[1]], 'y_range': [0, np_analog_stream_0_data_spikes.shape[0]]}

raster = figure(title="Raster Plot", **plot_size_and_tools)
raster.circle(x=spike_times, y=electrode_spike, color=colors)

show(raster) #show(raster, notebook_url="localhost:8890") 

Neural Network

Build a 2-hidden layers fully connected neural network (a.k.a multilayer perceptron) with TensorFlow. We are using some of TensorFlow higher-level wrappers (tf.estimators, tf.layers, tf.metrics, ...)

Data Manipulation

You can define the prediction time below with the shift attribute (unit milisecond, using 100ms by default)

In [9]:
data = normalize(analog_stream_0_data[:].transpose().astype(np.float32), norm='max')
seizure_prediction_labels = seizure_detection

MS_TO_US = 1000
shift = 100
            
for index,_ in enumerate(seizure_prediction_labels):
    if seizure_prediction_labels[index]:
        predict = index - shift * MS_TO_US
        if predict > 0:
            seizure_prediction_labels[predict] = 1
            pre_seizure = predict + 1
            while pre_seizure < index:
                seizure_prediction_labels[pre_seizure] = 1
                pre_seizure +=1
            i = index+1
            while seizure_prediction_labels[i]:
                seizure_prediction_labels[i] = 0
                i+=1
        else:
            seizure_prediction_labels[index] = 0
    else:
        seizure_prediction_labels[index]=0

data_train, data_test, labels_train, labels_test = train_test_split(data, seizure_prediction_labels, test_size=0.20, random_state=42)
unique, counts = np.unique(seizure_prediction_labels, return_counts=True)
ratio_dict = dict(zip(unique, counts))
try:
    ratio_dict[True]
except KeyError:
    ratio_dict[True] = 0
try:
    ratio_dict[False]
except KeyError:
    ratio_dict[False] = 0
print(ratio_dict)
{False: 2614931, True: 385069}

Configs

In [27]:
#Parameters:
learning_rate = 0.1
num_steps = 1000
batch_size = 128
display_step = 100

#Network Parameters:
n_hidden_1 = 256 # 1st layer number of neurons - not related with the electrodes just a common number
n_hidden_2 = 256 # 2nd layer number of neurons
num_input = 256 # data input (records shape: 256)
num_classes = 2 # total classes (pre-seizure/non-pre-seizure)

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'data': data_train}, y=labels_train,
    batch_size=batch_size, num_epochs=None, shuffle=True)
 
# Define the neural network
def neural_net(x_dict):
    # TF Estimator input is a dict, in case of multiple inputs
    x = x_dict['data']
    # Hidden fully connected layer with 256 neurons
    layer_1 = tf.layers.dense(x, n_hidden_1)
    # Hidden fully connected layer with 256 neurons
    layer_2 = tf.layers.dense(layer_1, n_hidden_2)
    # Output fully connected layer with a neuron for each class
    out_layer = tf.layers.dense(layer_2, num_classes)
    return out_layer

# Define the model function (following TF Estimator Template)
def model_fn(features, labels, mode):
    
    # Build the neural network
    logits = neural_net(features)
    
    # Predictions
    pred_classes = tf.argmax(logits, axis=1)
    pred_probas = tf.nn.softmax(logits)
    
    # If prediction mode, early return
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes) 
        
    # Define loss and optimizer
    ratio = ratio_dict[True] / (ratio_dict[False] + ratio_dict[True])
    class_weight = tf.constant([ratio, 1.0 - ratio])
    weighted_logits = tf.multiply(logits, class_weight)
    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=weighted_logits, labels=tf.cast(labels, dtype=tf.int32)))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(loss_op, global_step=tf.train.get_global_step())
    
    # Evaluate the the model
    acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes)
    fp_op = tf.metrics.false_positives(labels=labels, predictions=pred_classes)
    fn_op = tf.metrics.false_negatives(labels=labels, predictions=pred_classes)
    tp_op = tf.metrics.true_positives(labels=labels, predictions=pred_classes)
    tn_op = tf.metrics.true_negatives(labels=labels, predictions=pred_classes)
    f1_op = tf.contrib.metrics.f1_score(labels=labels, predictions=pred_classes)
    
    # TF Estimators requires to return a EstimatorSpec, that specify
    # the different ops for training, evaluating, ...
    estim_specs = tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=pred_classes,
      loss=loss_op,
      train_op=train_op,
      eval_metric_ops={'accuracy': acc_op,
                       'false_positives': fp_op,
                       'false_negatives': fn_op,
                       'true_positives': tp_op,
                       'true_negatives': tn_op,
                       'f1_score': f1_op,
                       })

    return estim_specs

# Build the Estimator
model = tf.estimator.Estimator(model_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3t24ftn8
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3t24ftn8
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3t24ftn8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fbd6d97bfd0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3t24ftn8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fbd6d97bfd0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Train

In [28]:
# Train the Model
model.train(input_fn, steps=num_steps)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3t24ftn8/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3t24ftn8/model.ckpt.
INFO:tensorflow:loss = 0.71053183, step = 1
INFO:tensorflow:loss = 0.71053183, step = 1
INFO:tensorflow:global_step/sec: 144.841
INFO:tensorflow:global_step/sec: 144.841
INFO:tensorflow:loss = 0.35613683, step = 101 (0.698 sec)
INFO:tensorflow:loss = 0.35613683, step = 101 (0.698 sec)
INFO:tensorflow:global_step/sec: 244.618
INFO:tensorflow:global_step/sec: 244.618
INFO:tensorflow:loss = 0.4892766, step = 201 (0.406 sec)
INFO:tensorflow:loss = 0.4892766, step = 201 (0.406 sec)
INFO:tensorflow:global_step/sec: 178.78
INFO:tensorflow:global_step/sec: 178.78
INFO:tensorflow:loss = 0.3325999, step = 301 (0.554 sec)
INFO:tensorflow:loss = 0.3325999, step = 301 (0.554 sec)
INFO:tensorflow:global_step/sec: 200.161
INFO:tensorflow:global_step/sec: 200.161
INFO:tensorflow:loss = 0.4307215, step = 401 (0.500 sec)
INFO:tensorflow:loss = 0.4307215, step = 401 (0.500 sec)
INFO:tensorflow:global_step/sec: 190.601
INFO:tensorflow:global_step/sec: 190.601
INFO:tensorflow:loss = 0.37734658, step = 501 (0.528 sec)
INFO:tensorflow:loss = 0.37734658, step = 501 (0.528 sec)
INFO:tensorflow:global_step/sec: 199.82
INFO:tensorflow:global_step/sec: 199.82
INFO:tensorflow:loss = 0.3582129, step = 601 (0.499 sec)
INFO:tensorflow:loss = 0.3582129, step = 601 (0.499 sec)
INFO:tensorflow:global_step/sec: 201.16
INFO:tensorflow:global_step/sec: 201.16
INFO:tensorflow:loss = 0.5170686, step = 701 (0.503 sec)
INFO:tensorflow:loss = 0.5170686, step = 701 (0.503 sec)
INFO:tensorflow:global_step/sec: 203.339
INFO:tensorflow:global_step/sec: 203.339
INFO:tensorflow:loss = 0.29608124, step = 801 (0.486 sec)
INFO:tensorflow:loss = 0.29608124, step = 801 (0.486 sec)
INFO:tensorflow:global_step/sec: 199.673
INFO:tensorflow:global_step/sec: 199.673
INFO:tensorflow:loss = 0.36291677, step = 901 (0.500 sec)
INFO:tensorflow:loss = 0.36291677, step = 901 (0.500 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/tmp3t24ftn8/model.ckpt.
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/tmp3t24ftn8/model.ckpt.
INFO:tensorflow:Loss for final step: 0.38857606.
INFO:tensorflow:Loss for final step: 0.38857606.
Out[28]:
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fbd6d825588>

Evaluate

In [29]:
# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'data': data_test}, y=labels_test,
    batch_size=batch_size, shuffle=False)
# Use the Estimator 'evaluate' method
model_eval = model.evaluate(input_fn)

model_eval['specificity'] = model_eval['true_negatives']/ratio_dict[False]
model_eval['sensitivity'] = model_eval['true_positives']/ratio_dict[True]
print(model_eval)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-06-09T21:13:40Z
INFO:tensorflow:Starting evaluation at 2019-06-09T21:13:40Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp3t24ftn8/model.ckpt-1000
INFO:tensorflow:Restoring parameters from /tmp/tmp3t24ftn8/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2019-06-09-21:14:01
INFO:tensorflow:Finished evaluation at 2019-06-09-21:14:01
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.872285, f1_score = 0.22751863, false_negatives = 76607.0, false_positives = 22.0, global_step = 1000, loss = 0.38276878, true_negatives = 522961.0, true_positives = 410.0
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.872285, f1_score = 0.22751863, false_negatives = 76607.0, false_positives = 22.0, global_step = 1000, loss = 0.38276878, true_negatives = 522961.0, true_positives = 410.0
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmp/tmp3t24ftn8/model.ckpt-1000
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmp/tmp3t24ftn8/model.ckpt-1000
{'accuracy': 0.872285, 'f1_score': 0.22751863, 'false_negatives': 76607.0, 'false_positives': 22.0, 'loss': 0.38276878, 'true_negatives': 522961.0, 'true_positives': 410.0, 'global_step': 1000, 'specificity': 0.1999903630344357, 'sensitivity': 0.0010647442406425862}

Predict

In [30]:
# Predict single images
n_data = 4
# Get images from test set
test_images = data_test[:n_data]
# Prepare the input data
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'data': test_images}, shuffle=False)
# Use the model to predict the images class
preds = list(model.predict(input_fn))

# Display
for i in range(n_data):
    print("Model prediction:", preds[i])
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp3t24ftn8/model.ckpt-1000
INFO:tensorflow:Restoring parameters from /tmp/tmp3t24ftn8/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
Model prediction: 0
Model prediction: 0
Model prediction: 0
Model prediction: 0
In [ ]: